feat(nvidia): add ntops rms norm backend#616
Conversation
| option(WITH_TORCH "Enable PyTorch C++ backend" OFF) | ||
|
|
||
| option(WITH_NINETOOTHED "Enable NineToothed-generated NVIDIA kernels" OFF) | ||
| set(NINETOOTHED_PYTHON_EXECUTABLE "" CACHE FILEPATH "Python executable used to run ninetoothed code generation") |
There was a problem hiding this comment.
这部分主要是用来写 option 的,请把下面这堆 set 给挪到一个专门的 section。
There was a problem hiding this comment.
已改。WITH_NINETOOTHED 仍然放在 option 区,下面这些 cache 变量已经挪到单独的 NineToothed code generation configuration section 里。
| _SUPPORTED_OPS = ("rms_norm",) | ||
|
|
||
|
|
||
| def _import_ninetoothed(source_dir): |
There was a problem hiding this comment.
为啥不直接是 import ninetoothed?且请不要使用缩写,直接使用全称 ninetoothed。
There was a problem hiding this comment.
已改。现在实现移到了 src/native/ninetoothed/codegen.py,_import_ninetoothed 只在可选 source dir 需要时调整 sys.path,随后直接 import ninetoothed,变量名也不再用 nt 缩写。
| return nt | ||
|
|
||
|
|
||
| def _import_ntops(): |
There was a problem hiding this comment.
为啥不直接是 import ntops?
There was a problem hiding this comment.
已改。_rms_norm_premake 里直接 import ntops,不再通过 importlib 或 _import_ntops 包一层。
|
|
||
| _DEFAULT_DTYPES = ("float32", "float16", "bfloat16") | ||
|
|
||
| _DEFAULT_RMS_NORM_SHAPES = ( |
There was a problem hiding this comment.
具体算子相关的内容,应该放到 src 里,scripts 里面只放纯功能性工具或者构建相关脚本。
There was a problem hiding this comment.
已改。scripts/generate_ninetoothed_ops.py 现在只作为构建入口,把 src 加到 sys.path 后委托给 native.ninetoothed.codegen.main();具体算子和生成逻辑放到了 src/native/ninetoothed/codegen.py。
| return importlib.import_module("ntops") | ||
|
|
||
|
|
||
| def _rms_norm_premake_rank2(dim0, dim1, dtype, block_size): |
There was a problem hiding this comment.
同上,这部分应当放在 src 中合适的地方,而不是在 scripts 下。以下同类问题不再赘述,但请一并修改。
There was a problem hiding this comment.
已一并修改。RmsNorm 的 premake 包装、rank/dtype config、manifest 生成都挪到了 src/native/ninetoothed/codegen.py,scripts 下不再放算子细节。
| return arrangement, application, tensors | ||
|
|
||
|
|
||
| def _rms_norm_premake_rank3(dim0, dim1, dim2, dtype, block_size): |
There was a problem hiding this comment.
这为啥要分 rank?不是只有 shape 不一样,那不是传个 shape 就行了嘛?
There was a problem hiding this comment.
已改。之前按 rank 拆函数是为了让 ninetoothed.build 生成不同 launcher 参数;现在改成使用 ntops 自带的动态-rank premake,只按 ndim/dtype 生成同一个 infiniops_ninetoothed_rms_norm dispatcher,Python 侧不再拆 rank2/rank3 premake。
| return arrangement, application, tensors | ||
|
|
||
|
|
||
| def _parse_shape(value): |
There was a problem hiding this comment.
已删除。现在不再按具体 shape 编译,也不需要解析 1x64 这类 shape 字符串;配置改为 INFINIOPS_NINETOOTHED_RMS_NORM_NDIMS / --rms-norm-ndims。
|
|
||
| namespace detail { | ||
|
|
||
| inline int NineToothedRmsNormDTypeIndex(DataType dtype) { |
There was a problem hiding this comment.
这种类似的 helper 不是应该是整个九齿 common 的嘛?不要放在 rms_norm 下面。
There was a problem hiding this comment.
已改。DTypeIndex、SizeArg、FromTensor、FromScalar 都放到 src/native/ninetoothed/tensor.h 作为九齿 common helper;rms_norm/ninetoothed.h 只保留 ExpandedRmsNormWeight 这种算子特有适配和 generated launcher 调用。
fa89de9 to
0ad2354
Compare
| _SUPPORTED_OPS = ("rms_norm",) | ||
|
|
||
|
|
||
| def _import_ninetoothed(source_dir): |
There was a problem hiding this comment.
到底为啥需要这个 helper?去掉它,直接在 top-level import ninetoothed 就行。
|
|
||
| option(WITH_TORCH "Enable PyTorch C++ backend" OFF) | ||
|
|
||
| option(WITH_NINETOOTHED "Enable NineToothed-generated NVIDIA kernels" OFF) |
There was a problem hiding this comment.
此处不提及 NVIDIA,因为九齿的目标是跨平台,只是目前可能只暴露了 cuda caller,所以跟 PyTorch 的对齐即可。
There was a problem hiding this comment.
九齿的定位在算子库中应该跟 PyTorch 差不多,都可以接入到后端里,所以在文件结构上应该跟 PyTorch 平行,而不是放在 cuda 下,现在的九齿可能只有 cuda 这个 caller,但是生成的接口是一致的,只要后期增多了支持,就可以跨平台,跟 PyTorch 一样。
| #ifndef INFINI_OPS_NVIDIA_RMS_NORM_NINETOOTHED_H_ | ||
| #define INFINI_OPS_NVIDIA_RMS_NORM_NINETOOTHED_H_ | ||
|
|
||
| #ifdef WITH_NINETOOTHED |
There was a problem hiding this comment.
就像评论 https://github.com/InfiniTensor/InfiniOps/pull/616/changes#r3285560677 所说,九齿应该是与 PyTorch 等对应的后端,所以是跟 torch 差不多的文件架构,而咱们算子库都是靠 build system 和脚本来确定最终产物,所以不要在 src 里面的文件使用 WITH_NINETOOTHED 这种类似的宏。事实上,在 C++ 中,我们应当尽量少地使用宏。
| #include "rms_norm/infiniops_ninetoothed_rms_norm.h" | ||
|
|
||
| #ifndef INFINIOPS_NINETOOTHED_BLOCK_SIZE | ||
| #define INFINIOPS_NINETOOTHED_BLOCK_SIZE 256 |
There was a problem hiding this comment.
C++ 中尽量不使用宏,尤其是这种可以被 constexpr 或者 const 替代的情况。
| }; | ||
| } | ||
|
|
||
| template <typename NineToothedTensor, typename T> |
There was a problem hiding this comment.
这个模板参数的意义是什么?暂时是冗余的。如无必要,吴增实体。
There was a problem hiding this comment.
已处理:删掉了原来的 FromScalar<NineToothedTensor> 函数模板,标量也统一通过 ninetoothed::Tensor(value, empty_shape, empty_strides) 包装后传给生成 launcher。
There was a problem hiding this comment.
我这里有一套方案,可以看看是不是更好一些:我们直接提供一个 infini::ops::ninetoothed::Tensor 类,这个类里面去定义从 infini::ops::Tensor 或者 scalar 到它的 implicit conversion。这样用起来会不会更方便一些?也请考虑一些可能的风险,综合评判是否这么做。
There was a problem hiding this comment.
已按这个方向调整:现在提供 infini::ops::ninetoothed::Tensor,负责从 InfiniOps Tensor、标量或自定义 shape/stride 视图适配到 NineToothed launcher 参数。一个实现上的取舍是:公共头不直接依赖某个 op 生成出的 NineToothedTensor 定义,因为这个类型来自生成 header;因此转换在调用点按目标参数类型延迟实例化。这样可以保留 ninetoothed::Tensor(input) 的简洁使用方式,同时避免 include 顺序和 op-specific 生成头泄漏到公共适配层。
|
|
||
| set(${out_var} "" PARENT_SCOPE) | ||
| endfunction() | ||
| # NineToothed code generation configuration. |
There was a problem hiding this comment.
请在后面创建一个关于 WITH_NINETOOTHED 的 if 吧,把这些 set 放到这个分支里吧。
There was a problem hiding this comment.
已处理:NINETOOTHED_* 和 INFINIOPS_NINETOOTHED_* cache 配置现在都放在后面的 if(WITH_NINETOOTHED) 分支里,默认不开启时不再提前暴露这些配置项。
| import sys | ||
| import tempfile | ||
| import types | ||
| import unittest |
There was a problem hiding this comment.
为什么引入了 unittest?请统一使用 pytest,与其他测试保持一致。
There was a problem hiding this comment.
已处理:测试已改成 pytest 风格的普通测试函数,使用 monkeypatch,不再引入 unittest 或 unittest.mock。
|
|
||
| namespace infini::ops { | ||
|
|
||
| namespace detail { |
There was a problem hiding this comment.
我看此处 detail 内部的函数比较少,且每个函数的内容也很少,可以考虑直接放在 Operator<RmsNorm, Device::Type::kNvidia, 9>::operator() 里面,暂时不单独抽成独立的 helper 了。
There was a problem hiding this comment.
已处理:去掉了 detail namespace 中的两个小 helper,把 weight 扩展、dtype index 和 generated launcher 调用都放回 Operator<RmsNorm, Device::Type::kNvidia, 9>::operator() 内。当前逻辑比较短,内联后更直接。
0ad2354 to
eff11f2
Compare
@/tmp/pr616-body.md